{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# `clip_retrieval.clip_client`\n",
    "\n",
    "This Python module lets you query a remote backend through its exposed REST API."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "uT9FwUjk_lRD"
   },
   "source": [
    "## Install"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: clip-retrieval in /home/rom1504/clip-retrieval (2.33.0)\n",
      "Requirement already satisfied: img2dataset in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (1.25.5)\n",
      "Requirement already satisfied: aiohttp<4,>=3.8.1 in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from clip-retrieval) (3.8.1)\n",
      "Requirement already satisfied: autofaiss<3,>=2.9.6 in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from clip-retrieval) (2.9.8)\n",
      "Requirement already satisfied: clip-anytorch<3,>=2.3.1 in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from clip-retrieval) (2.3.1)\n",
      "Requirement already satisfied: faiss-cpu<2,>=1.7.2 in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from clip-retrieval) (1.7.2)\n",
      "Requirement already satisfied: fire<0.5.0,>=0.4.0 in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from clip-retrieval) (0.4.0)\n",
      "Requirement already satisfied: flask<3,>=2.0.3 in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from clip-retrieval) (2.0.3)\n",
      "Requirement already satisfied: flask_cors<4,>=3.0.10 in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from clip-retrieval) (3.0.10)\n",
      "Requirement already satisfied: flask_restful<1,>=0.3.9 in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from clip-retrieval) (0.3.9)\n",
      "Requirement already satisfied: fsspec==2022.1.0 in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from clip-retrieval) (2022.1.0)\n",
      "Requirement already satisfied: h5py<4,>=3.1.0 in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from clip-retrieval) (3.6.0)\n",
      "Requirement already satisfied: multilingual-clip<2,>=1.0.10 in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from clip-retrieval) (1.0.10)\n",
      "Requirement already satisfied: numpy<2,>=1.19.5 in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from clip-retrieval) (1.22.2)\n",
      "Requirement already satisfied: open-clip-torch<2.0.0,>=1.0.1 in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from clip-retrieval) (1.2.0)\n",
      "Requirement already satisfied: pandas<2,>=1.1.5 in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from clip-retrieval) (1.4.0)\n",
      "Requirement already satisfied: prometheus-client<1,>=0.13.1 in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from clip-retrieval) (0.13.1)\n",
      "Requirement already satisfied: pyarrow<8,>=6.0.1 in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from clip-retrieval) (7.0.0)\n",
      "Requirement already satisfied: requests<3,>=2.27.1 in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from clip-retrieval) (2.27.1)\n",
      "Requirement already satisfied: sentence-transformers<3,>=2.2.0 in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from clip-retrieval) (2.2.0)\n",
      "Requirement already satisfied: torch<2,>=1.7.1 in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from clip-retrieval) (1.10.2+cu113)\n",
      "Requirement already satisfied: torchvision<2,>=0.10.1 in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from clip-retrieval) (0.11.3+cu113)\n",
      "Requirement already satisfied: tqdm<5,>=4.62.3 in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from clip-retrieval) (4.62.3)\n",
      "Requirement already satisfied: wandb<0.13,>=0.12.10 in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from clip-retrieval) (0.12.10)\n",
      "Requirement already satisfied: webdataset<0.2,>=0.1.103 in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from clip-retrieval) (0.1.103)\n",
      "Requirement already satisfied: exifread-nocycle<4,>=3.0.1 in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from img2dataset) (3.0.1)\n",
      "Requirement already satisfied: albumentations<2,>=1.1.0 in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from img2dataset) (1.1.0)\n",
      "Requirement already satisfied: opencv-python<5,>=4.5.5.62 in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from img2dataset) (4.5.5.62)\n",
      "Requirement already satisfied: dataclasses<1.0.0,>=0.6 in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from img2dataset) (0.6)\n",
      "Requirement already satisfied: charset-normalizer<3.0,>=2.0 in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from aiohttp<4,>=3.8.1->clip-retrieval) (2.0.11)\n",
      "Requirement already satisfied: aiosignal>=1.1.2 in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from aiohttp<4,>=3.8.1->clip-retrieval) (1.2.0)\n",
      "Requirement already satisfied: frozenlist>=1.1.1 in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from aiohttp<4,>=3.8.1->clip-retrieval) (1.3.0)\n",
      "Requirement already satisfied: attrs>=17.3.0 in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from aiohttp<4,>=3.8.1->clip-retrieval) (21.4.0)\n",
      "Requirement already satisfied: multidict<7.0,>=4.5 in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from aiohttp<4,>=3.8.1->clip-retrieval) (6.0.2)\n",
      "Requirement already satisfied: yarl<2.0,>=1.0 in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from aiohttp<4,>=3.8.1->clip-retrieval) (1.7.2)\n",
      "Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from aiohttp<4,>=3.8.1->clip-retrieval) (4.0.2)\n",
      "Requirement already satisfied: PyYAML in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from albumentations<2,>=1.1.0->img2dataset) (6.0)\n",
      "Requirement already satisfied: qudida>=0.0.4 in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from albumentations<2,>=1.1.0->img2dataset) (0.0.4)\n",
      "Requirement already satisfied: scipy in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from albumentations<2,>=1.1.0->img2dataset) (1.8.0)\n",
      "Requirement already satisfied: opencv-python-headless>=4.1.1 in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from albumentations<2,>=1.1.0->img2dataset) (4.5.5.62)\n",
      "Requirement already satisfied: scikit-image>=0.16.1 in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from albumentations<2,>=1.1.0->img2dataset) (0.19.1)\n",
      "Requirement already satisfied: ftfy in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from clip-anytorch<3,>=2.3.1->clip-retrieval) (6.0.3)\n",
      "Requirement already satisfied: regex in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from clip-anytorch<3,>=2.3.1->clip-retrieval) (2022.1.18)\n",
      "Requirement already satisfied: termcolor in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from fire<0.5.0,>=0.4.0->clip-retrieval) (1.1.0)\n",
      "Requirement already satisfied: six in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from fire<0.5.0,>=0.4.0->clip-retrieval) (1.16.0)\n",
      "Requirement already satisfied: Werkzeug>=2.0 in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from flask<3,>=2.0.3->clip-retrieval) (2.0.2)\n",
      "Requirement already satisfied: itsdangerous>=2.0 in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from flask<3,>=2.0.3->clip-retrieval) (2.0.1)\n",
      "Requirement already satisfied: click>=7.1.2 in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from flask<3,>=2.0.3->clip-retrieval) (8.0.3)\n",
      "Requirement already satisfied: Jinja2>=3.0 in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from flask<3,>=2.0.3->clip-retrieval) (3.0.3)\n",
      "Requirement already satisfied: pytz in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from flask_restful<1,>=0.3.9->clip-retrieval) (2021.3)\n",
      "Requirement already satisfied: aniso8601>=0.82 in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from flask_restful<1,>=0.3.9->clip-retrieval) (9.0.1)\n",
      "Requirement already satisfied: transformers in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from multilingual-clip<2,>=1.0.10->clip-retrieval) (4.19.2)\n",
      "Requirement already satisfied: python-dateutil>=2.8.1 in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from pandas<2,>=1.1.5->clip-retrieval) (2.8.2)\n",
      "Requirement already satisfied: idna<4,>=2.5 in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from requests<3,>=2.27.1->clip-retrieval) (3.3)\n",
      "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from requests<3,>=2.27.1->clip-retrieval) (1.26.8)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from requests<3,>=2.27.1->clip-retrieval) (2021.10.8)\n",
      "Requirement already satisfied: huggingface-hub in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from sentence-transformers<3,>=2.2.0->clip-retrieval) (0.4.0)\n",
      "Requirement already satisfied: nltk in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from sentence-transformers<3,>=2.2.0->clip-retrieval) (3.6.7)\n",
      "Requirement already satisfied: scikit-learn in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from sentence-transformers<3,>=2.2.0->clip-retrieval) (1.0.2)\n",
      "Requirement already satisfied: sentencepiece in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from sentence-transformers<3,>=2.2.0->clip-retrieval) (0.1.96)\n",
      "Requirement already satisfied: typing-extensions in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from torch<2,>=1.7.1->clip-retrieval) (4.0.1)\n",
      "Requirement already satisfied: pillow!=8.3.0,>=5.3.0 in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from torchvision<2,>=0.10.1->clip-retrieval) (9.0.1)\n",
      "Requirement already satisfied: GitPython>=1.0.0 in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from wandb<0.13,>=0.12.10->clip-retrieval) (3.1.26)\n",
      "Requirement already satisfied: yaspin>=1.0.0 in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from wandb<0.13,>=0.12.10->clip-retrieval) (2.1.0)\n",
      "Requirement already satisfied: psutil>=5.0.0 in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from wandb<0.13,>=0.12.10->clip-retrieval) (5.9.0)\n",
      "Requirement already satisfied: pathtools in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from wandb<0.13,>=0.12.10->clip-retrieval) (0.1.2)\n",
      "Requirement already satisfied: protobuf>=3.12.0 in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from wandb<0.13,>=0.12.10->clip-retrieval) (3.19.4)\n",
      "Requirement already satisfied: promise<3,>=2.0 in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from wandb<0.13,>=0.12.10->clip-retrieval) (2.3)\n",
      "Requirement already satisfied: docker-pycreds>=0.4.0 in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from wandb<0.13,>=0.12.10->clip-retrieval) (0.4.0)\n",
      "Requirement already satisfied: shortuuid>=0.5.0 in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from wandb<0.13,>=0.12.10->clip-retrieval) (1.0.8)\n",
      "Requirement already satisfied: sentry-sdk>=1.0.0 in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from wandb<0.13,>=0.12.10->clip-retrieval) (1.5.4)\n",
      "Requirement already satisfied: braceexpand in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from webdataset<0.2,>=0.1.103->clip-retrieval) (0.1.7)\n",
      "Requirement already satisfied: gitdb<5,>=4.0.1 in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from GitPython>=1.0.0->wandb<0.13,>=0.12.10->clip-retrieval) (4.0.9)\n",
      "Requirement already satisfied: MarkupSafe>=2.0 in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from Jinja2>=3.0->flask<3,>=2.0.3->clip-retrieval) (2.0.1)\n",
      "Requirement already satisfied: imageio>=2.4.1 in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from scikit-image>=0.16.1->albumentations<2,>=1.1.0->img2dataset) (2.14.1)\n",
      "Requirement already satisfied: networkx>=2.2 in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from scikit-image>=0.16.1->albumentations<2,>=1.1.0->img2dataset) (2.6.3)\n",
      "Requirement already satisfied: tifffile>=2019.7.26 in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from scikit-image>=0.16.1->albumentations<2,>=1.1.0->img2dataset) (2022.2.2)\n",
      "Requirement already satisfied: PyWavelets>=1.1.1 in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from scikit-image>=0.16.1->albumentations<2,>=1.1.0->img2dataset) (1.2.0)\n",
      "Requirement already satisfied: packaging>=20.0 in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from scikit-image>=0.16.1->albumentations<2,>=1.1.0->img2dataset) (21.3)\n",
      "Requirement already satisfied: threadpoolctl>=2.0.0 in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from scikit-learn->sentence-transformers<3,>=2.2.0->clip-retrieval) (3.1.0)\n",
      "Requirement already satisfied: joblib>=0.11 in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from scikit-learn->sentence-transformers<3,>=2.2.0->clip-retrieval) (1.1.0)\n",
      "Requirement already satisfied: tokenizers!=0.11.3,<0.13,>=0.11.1 in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from transformers->multilingual-clip<2,>=1.0.10->clip-retrieval) (0.11.4)\n",
      "Requirement already satisfied: filelock in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from transformers->multilingual-clip<2,>=1.0.10->clip-retrieval) (3.4.2)\n",
      "Requirement already satisfied: wcwidth in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from ftfy->clip-anytorch<3,>=2.3.1->clip-retrieval) (0.2.5)\n",
      "Requirement already satisfied: smmap<6,>=3.0.1 in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from gitdb<5,>=4.0.1->GitPython>=1.0.0->wandb<0.13,>=0.12.10->clip-retrieval) (5.0.0)\n",
      "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /home/rom1504/clip-retrieval/.env/lib/python3.8/site-packages (from packaging>=20.0->scikit-image>=0.16.1->albumentations<2,>=1.1.0->img2dataset) (3.0.7)\n",
      "\u001b[33mWARNING: You are using pip version 22.0.3; however, version 22.1.2 is available.\n",
      "You should consider upgrading via the '/home/rom1504/clip-retrieval/.env/bin/python -m pip install --upgrade pip' command.\u001b[0m\u001b[33m\n",
      "\u001b[0mNote: you may need to restart the kernel to use updated packages.\n"
     ]
    }
   ],
   "source": [
    "%pip install clip-retrieval img2dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import Image, display\n",
    "from clip_retrieval.clip_client import ClipClient, Modality\n",
    "\n",
    "IMAGE_BASE_URL = \"https://github.com/rom1504/clip-retrieval/raw/main/tests/test_clip_inference/test_images/\"\n",
    "\n",
    "def log_result(result):\n",
    "    \"\"\"Print one search result's metadata and render its image inline.\n",
    "\n",
    "    Args:\n",
    "        result: mapping with the keys ``id``, ``caption``, ``url`` and\n",
    "            ``similarity`` (one element of the list returned by ``client.query``).\n",
    "\n",
    "    Raises:\n",
    "        KeyError: if any of the four keys is missing from ``result``.\n",
    "    \"\"\"\n",
    "    # Read every field up front so a missing key fails before anything is printed,\n",
    "    # and keep the values in a dict so the builtin `id` is not shadowed.\n",
    "    fields = {key: result[key] for key in (\"id\", \"caption\", \"url\", \"similarity\")}\n",
    "    for key, value in fields.items():\n",
    "        print(f\"{key}: {value}\")\n",
    "    # Passing `url=` makes IPython emit an <img> tag pointing at the remote image.\n",
    "    display(Image(url=fields[\"url\"], unconfined=True))\n",
    "\n",
    "# Client used by the query examples below; it targets the knn.laion.ai backend\n",
    "# and the `laion5B-L-14` index, and requests 10 results per query.\n",
    "# NOTE(review): the exact meaning of `aesthetic_score` / `aesthetic_weight`\n",
    "# (presumably a bias toward higher predicted aesthetic quality) is defined by\n",
    "# the backend and not shown in this notebook — confirm against the client docs.\n",
    "client = ClipClient(\n",
    "    url=\"https://knn.laion.ai/knn-service\",\n",
    "    indice_name=\"laion5B-L-14\",\n",
    "    aesthetic_score=9,\n",
    "    aesthetic_weight=0.5,\n",
    "    modality=Modality.IMAGE,\n",
    "    num_images=10,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Query by text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "id: 518836491\n",
      "caption: orange cat with supicious look stock photo\n",
      "url: https://media.istockphoto.com/photos/orange-cat-with-supicious-look-picture-id907595140?k=6&amp;m=907595140&amp;s=612x612&amp;w=0&amp;h=4CTvSxNvv4sxSCPxViryha4kAjuxDbrXM5vy4VPOuzk=\n",
      "similarity: 0.5591723918914795\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<img src=\"https://media.istockphoto.com/photos/orange-cat-with-supicious-look-picture-id907595140?k=6&amp;m=907595140&amp;s=612x612&amp;w=0&amp;h=4CTvSxNvv4sxSCPxViryha4kAjuxDbrXM5vy4VPOuzk=\" class=\"unconfined\"/>"
      ],
      "text/plain": [
       "<IPython.core.display.Image object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "cat_results = client.query(text=\"an image of a cat\")\n",
    "log_result(cat_results[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Query by image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "id: 574870177\n",
      "caption: Palm trees in Orlando, Florida\n",
      "url: https://www.polefitfreedom.com/wp-content/uploads/2018/03/Orlando.jpg\n",
      "similarity: 0.9619368314743042\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<img src=\"https://www.polefitfreedom.com/wp-content/uploads/2018/03/Orlando.jpg\" class=\"unconfined\"/>"
      ],
      "text/plain": [
       "<IPython.core.display.Image object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Build the URL from IMAGE_BASE_URL (setup cell) instead of repeating the repository path.\n",
    "beach_results = client.query(image=f\"{IMAGE_BASE_URL}321_421.jpg\")\n",
    "log_result(beach_results[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Query by embedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "metadata": {},
   "outputs": [],
   "source": [
    "import clip  # pylint: disable=import-outside-toplevel\n",
    "import torch\n",
    "\n",
    "# Load CLIP ViT-L/14 on CPU; `clip.load` hands back the model together with the\n",
    "# image preprocessing callable applied in `get_image_emb` below.\n",
    "# NOTE(review): presumably ViT-L/14 is chosen to match the `laion5B-L-14` index\n",
    "# queried by `client` — confirm.\n",
    "model, preprocess = clip.load(\"ViT-L/14\", device=\"cpu\", jit=True)\n",
    "\n",
    "import io\n",
    "import urllib.request  # plain `import urllib` does not guarantee the `request` submodule is loaded\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "def download_image(url):\n",
    "    \"\"\"Download ``url`` and return its content as an in-memory byte stream.\n",
    "\n",
    "    Args:\n",
    "        url: address of the resource to fetch.\n",
    "\n",
    "    Returns:\n",
    "        io.BytesIO holding the full response body.\n",
    "\n",
    "    Raises:\n",
    "        urllib.error.URLError: if the request fails (propagated from ``urlopen``).\n",
    "    \"\"\"\n",
    "    # NOTE(review): a browser-like User-Agent is sent, presumably because some image\n",
    "    # hosts reject urllib's default one — confirm before changing it.\n",
    "    request = urllib.request.Request(\n",
    "        url,\n",
    "        data=None,\n",
    "        headers={\"User-Agent\": \"Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:72.0) Gecko/20100101 Firefox/72.0\"},\n",
    "    )\n",
    "    # The 10 s timeout keeps an unresponsive host from hanging the notebook.\n",
    "    with urllib.request.urlopen(request, timeout=10) as response:\n",
    "        return io.BytesIO(response.read())\n",
    "def normalized(a, axis=-1, order=2):\n",
    "    \"\"\"Scale ``a`` to unit norm along ``axis``; all-zero slices are left unchanged.\n",
    "\n",
    "    Args:\n",
    "        a: array to normalize.\n",
    "        axis: axis along which the norm is taken (default: last axis).\n",
    "        order: order of the norm passed to ``np.linalg.norm`` (default: 2).\n",
    "\n",
    "    Returns:\n",
    "        ``a`` divided by its norms. A 1-D input comes back with shape ``(1, n)``\n",
    "        because the scalar norm is promoted to 1-D and then expanded before dividing.\n",
    "    \"\"\"\n",
    "    norms = np.linalg.norm(a, ord=order, axis=axis)\n",
    "    norms = np.atleast_1d(norms)\n",
    "    # Swap zero norms for 1 so zero vectors divide cleanly instead of producing NaN.\n",
    "    zero_mask = norms == 0\n",
    "    norms[zero_mask] = 1\n",
    "    divisor = np.expand_dims(norms, axis)\n",
    "    return a / divisor\n",
    "\n",
    "def get_text_emb(text):\n",
    "    \"\"\"Encode ``text`` with the notebook-level CLIP ``model`` into a unit-norm float32 NumPy array.\n",
    "\n",
    "    Args:\n",
    "        text: the prompt to embed; it is tokenized with ``truncate=True``.\n",
    "\n",
    "    Returns:\n",
    "        The first (and only) row of the normalized text features, as float32.\n",
    "    \"\"\"\n",
    "    tokens = clip.tokenize([text], truncate=True).to(\"cpu\")\n",
    "    with torch.no_grad():\n",
    "        features = model.encode_text(tokens)\n",
    "        features /= features.norm(dim=-1, keepdim=True)\n",
    "    # The batch holds a single prompt, so index 0 drops the batch axis.\n",
    "    return features.cpu().detach().numpy().astype(\"float32\")[0]\n",
    "\n",
    "from PIL import Image as pimage\n",
    "\n",
    "def get_image_emb(image_url):\n",
    "    \"\"\"Download the image at ``image_url`` and encode it with the notebook-level CLIP ``model``.\n",
    "\n",
    "    Args:\n",
    "        image_url: address of the image, fetched with ``download_image``.\n",
    "\n",
    "    Returns:\n",
    "        The first (and only) row of the normalized image features, as a float32 NumPy array.\n",
    "    \"\"\"\n",
    "    picture = pimage.open(download_image(image_url))\n",
    "    with torch.no_grad():\n",
    "        # `unsqueeze(0)` adds the batch axis before the image goes through the encoder.\n",
    "        batch = preprocess(picture).unsqueeze(0).to(\"cpu\")\n",
    "        features = model.encode_image(batch)\n",
    "        features /= features.norm(dim=-1, keepdim=True)\n",
    "    # Single-image batch, so index 0 drops the batch axis.\n",
    "    return features.cpu().detach().numpy().astype(\"float32\")[0]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "id: 3142652315\n",
      "caption: KCTS29 100% Cotton Unisex Kids T Shirt Crew Neck V Neck Solid Maroon\n",
      "url: https://cdn.shopify.com/s/files/1/1531/9423/products/CTS29_100_Cotton_T_Shirt_Crew_Neck_V_Neck_Long_Sleeves_Solid_Maroon_27f4c580-c286-4e20-b7d7-110c4b2a3ff3_large.jpg?v=1476875623\n",
      "similarity: 0.5375761389732361\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<img src=\"https://cdn.shopify.com/s/files/1/1531/9423/products/CTS29_100_Cotton_T_Shirt_Crew_Neck_V_Neck_Long_Sleeves_Solid_Maroon_27f4c580-c286-4e20-b7d7-110c4b2a3ff3_large.jpg?v=1476875623\" class=\"unconfined\"/>"
      ],
      "text/plain": [
       "<IPython.core.display.Image object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "red_tshirt_text_emb =  get_text_emb(\"red tshirt\")\n",
    "red_tshirt_results = client.query(embedding_input=red_tshirt_text_emb.tolist())\n",
    "log_result(red_tshirt_results[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "id: 2463946620\n",
      "caption: 8c7889e0b92b Cinderella Divine 1295 Long Chiffon Grecian Royal Blue Dress Mid Length  Sleeves V Neck ...\n",
      "url: https://cdn.shopify.com/s/files/1/1417/0920/products/1295cd-royal-blue_cfcbd4bc-ed74-47c0-8659-c1b8691990df.jpg?v=1527650905\n",
      "similarity: 0.9430060386657715\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<img src=\"https://cdn.shopify.com/s/files/1/1417/0920/products/1295cd-royal-blue_cfcbd4bc-ed74-47c0-8659-c1b8691990df.jpg?v=1527650905\" class=\"unconfined\"/>"
      ],
      "text/plain": [
       "<IPython.core.display.Image object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "blue_dress_image_emb = get_image_emb(\"https://rukminim1.flixcart.com/image/612/612/kv8fbm80/dress/b/5/n/xs-b165-royal-blue-babiva-fashion-original-imag86psku5pbx2g.jpeg?q=70\")\n",
    "blue_dress_results = client.query(embedding_input=blue_dress_image_emb.tolist())\n",
    "log_result(blue_dress_results[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "id: 2702080924\n",
      "caption: CLEARANCE - Long Chiffon Grecian Red Dress Mid Length Sleeves V Neck (Size Medium)\n",
      "url: https://cdn-img-3.wanelo.com/p/716/c27/0c0/aef7a32a4317370b6f7f14b/x354-q80.jpg\n",
      "similarity: 0.824600338935852\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<img src=\"https://cdn-img-3.wanelo.com/p/716/c27/0c0/aef7a32a4317370b6f7f14b/x354-q80.jpg\" class=\"unconfined\"/>"
      ],
      "text/plain": [
       "<IPython.core.display.Image object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Combine a text embedding and an image embedding into a single query vector.\n",
    "red_tshirt_text_emb =  get_text_emb(\"red tshirt\")\n",
    "blue_dress_image_emb = get_image_emb(\"https://rukminim1.flixcart.com/image/612/612/kv8fbm80/dress/b/5/n/xs-b165-royal-blue-babiva-fashion-original-imag86psku5pbx2g.jpeg?q=70\")\n",
    "# The sum is rescaled to unit L2 norm. `normalized` promotes the norm to 1-D and expands it,\n",
    "# so a single-vector input comes back with an extra leading axis — hence the `[0]`.\n",
    "mean_emb = normalized(red_tshirt_text_emb + blue_dress_image_emb)[0]\n",
    "mean_results = client.query(embedding_input=mean_emb.tolist())\n",
    "log_result(mean_results[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Download and format a dataset from the results of a query\n",
    "\n",
    "If you have some images of your own, you can query with each one and collect the results into a custom dataset (a small subset of LAION-5B)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:17<00:00,  2.44s/it]\n"
     ]
    }
   ],
   "source": [
    "# Create urls from known images in repo\n",
    "import json\n",
    "from tqdm import tqdm\n",
    "\n",
    "test_image_names = [\n",
    "    \"123_456.jpg\", \"208_495.jpg\", \"321_421.jpg\", \"389_535.jpg\",\n",
    "    \"416_264.jpg\", \"456_123.jpg\", \"524_316.jpg\",\n",
    "]\n",
    "test_images = [f\"{IMAGE_BASE_URL}{name}\" for name in test_image_names]\n",
    "\n",
    "# Use a separately named client with a higher num_images. Rebinding `client` here would\n",
    "# silently change the settings used by the query cells above if they are re-run afterwards.\n",
    "dataset_client = ClipClient(url=\"https://knn.laion.ai/knn-service\", indice_name=\"laion5B-L-14\", num_images=40)\n",
    "\n",
    "# Run one query per image and pool every hit into a single list\n",
    "combined_results = []\n",
    "for image_url in tqdm(test_images):\n",
    "    combined_results.extend(dataset_client.query(image=image_url))\n",
    "\n",
    "# Save results to json file (passed to img2dataset in the next cell)\n",
    "with open(\"search-results.json\", \"w\") as f:\n",
    "    json.dump(combined_results, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Starting the downloading of this file\n",
      "Sharding file number 1 of 1 called /home/samsepiol/Projects/clip-retrieval/notebook/search-results.json\n",
      "0it [00:00, ?it/s]File sharded in 1 shards\n",
      "Downloading starting now, check your bandwidth speed (with bwm-ng)your cpu (with htop), and your disk usage (with iotop)!\n",
      "1it [00:32, 32.08s/it]\n",
      "worker  - success: 0.921 - failed to download: 0.079 - failed to resize: 0.000 - images per sec: 2 - count: 76\n",
      "total   - success: 0.921 - failed to download: 0.079 - failed to resize: 0.000 - images per sec: 2 - count: 76\n"
     ]
    }
   ],
   "source": [
    "!img2dataset \"search-results.json\" --input_format=\"json\" --caption_col \"caption\" --output_folder=\"laion-enhanced-dataset\" --resize_mode=\"no\" --output_format=\"files\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Done! Download/copy the contents of /home/samsepiol/Projects/clip-retrieval/notebook/laion-enhanced-images/\n",
      "/home/samsepiol/Projects/clip-retrieval/notebook/laion-enhanced-images/00000.tar\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "print(f\"Done! Download/copy the contents of {os.getcwd()}/laion-enhanced-dataset/\")\n",
    "!realpath laion-enhanced-dataset/"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "8890164936ba431effa62f548d2e190a63033d8c51925a70e93a060bef4e9d5d"
  },
  "kernelspec": {
   "display_name": "Python 3.8.10 ('.env': venv)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
